{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.697, 0.46 , 0.   ],\n",
       "        [1.   , 0.   , 1.   , 0.   , 0.   , 0.   , 0.774, 0.376, 0.   ],\n",
       "        [1.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.634, 0.264, 0.   ],\n",
       "        [0.   , 1.   , 0.   , 0.   , 1.   , 1.   , 0.403, 0.237, 0.   ],\n",
       "        [1.   , 1.   , 0.   , 1.   , 1.   , 1.   , 0.481, 0.149, 0.   ],\n",
       "        [0.   , 2.   , 2.   , 0.   , 2.   , 1.   , 0.243, 0.267, 1.   ],\n",
       "        [2.   , 1.   , 1.   , 1.   , 0.   , 0.   , 0.657, 0.198, 1.   ],\n",
       "        [1.   , 1.   , 0.   , 0.   , 1.   , 1.   , 0.36 , 0.37 , 1.   ],\n",
       "        [2.   , 0.   , 0.   , 2.   , 2.   , 0.   , 0.593, 0.042, 1.   ],\n",
       "        [0.   , 0.   , 1.   , 1.   , 1.   , 0.   , 0.719, 0.103, 1.   ]]),\n",
       " array([[0.   , 0.   , 1.   , 0.   , 0.   , 0.   , 0.608, 0.318, 0.   ],\n",
       "        [2.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.556, 0.215, 0.   ],\n",
       "        [1.   , 1.   , 0.   , 0.   , 1.   , 0.   , 0.437, 0.211, 0.   ],\n",
       "        [1.   , 1.   , 1.   , 1.   , 1.   , 0.   , 0.666, 0.091, 1.   ],\n",
       "        [2.   , 2.   , 2.   , 2.   , 2.   , 0.   , 0.245, 0.057, 1.   ],\n",
       "        [2.   , 0.   , 0.   , 2.   , 2.   , 1.   , 0.343, 0.099, 1.   ],\n",
       "        [0.   , 1.   , 0.   , 1.   , 0.   , 0.   , 0.639, 0.161, 1.   ]]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Load the dataset\n",
    "def load_data(path='连续值数据.txt', train_size=10):\n",
    "    \"\"\"Read a comma-separated numeric file and split it into train/test rows.\n",
    "\n",
    "    Args:\n",
    "        path: Text file with one sample per line and comma-separated numeric\n",
    "            fields. Defaults to the file name this notebook has always used.\n",
    "        train_size: Number of leading rows returned as the training set; the\n",
    "            remaining rows form the test set.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(x, test_x)`` of float arrays. The number of columns is taken\n",
    "        from the file rather than being fixed in the code.\n",
    "    \"\"\"\n",
    "    with open(path) as fr:\n",
    "        # Blank lines (for example a trailing newline at the end of the file)\n",
    "        # cannot be parsed as a row of floats, so they are skipped.\n",
    "        rows = [[float(field) for field in line.split(',')]\n",
    "                for line in fr if line.strip()]\n",
    "\n",
    "    data = np.array(rows, dtype=float)\n",
    "\n",
    "    return data[:train_size], data[train_size:]\n",
    "\n",
    "\n",
    "x, test_x = load_data()\n",
    "x, test_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Entropy of a dataset, measured on its label column y\n",
    "def get_entropy(_x):\n",
    "    \"\"\"Return the Shannon entropy (base 2) of the labels in ``_x``.\n",
    "\n",
    "    Args:\n",
    "        _x: 2-D array whose last column holds the labels. The labels are cast\n",
    "            to integers and counted with ``np.bincount``, so they must be\n",
    "            non-negative.\n",
    "\n",
    "    Returns:\n",
    "        The entropy of the label distribution; 0 when ``_x`` has no rows.\n",
    "    \"\"\"\n",
    "    entropy = 0\n",
    "\n",
    "    # Only the label column matters for the entropy.\n",
    "    _y = _x[:, -1]\n",
    "    # The ``np.int`` alias was removed from NumPy; the builtin ``int`` is the\n",
    "    # equivalent dtype and works on every NumPy version.\n",
    "    _y = _y.astype(int)\n",
    "\n",
    "    # Occurrences of each label value, e.g. [8 9] means label 0 appears 8\n",
    "    # times and label 1 appears 9 times.\n",
    "    bincount = np.bincount(_y)\n",
    "    for count in bincount:\n",
    "        # A label that never occurs contributes nothing, and log2(0) is\n",
    "        # undefined, so skip it.\n",
    "        if count == 0:\n",
    "            continue\n",
    "\n",
    "        # occurrences / total rows = probability of this label\n",
    "        prob = count / len(_x)\n",
    "\n",
    "        # entropy = -sum(p * log2(p))\n",
    "        entropy -= prob * np.log2(prob)\n",
    "\n",
    "    return entropy\n",
    "\n",
    "\n",
    "#数字越大,混乱度越高\n",
    "get_entropy(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.3015, 0.3815, 0.442 , 0.537 , 0.6135, 0.6455, 0.677 , 0.708 ,\n",
       "       0.7465])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_cut_point(values):\n",
    "    \"\"\"Return the candidate thresholds for splitting a continuous column.\n",
    "\n",
    "    Args:\n",
    "        values: 1-D array-like of numeric feature values.\n",
    "\n",
    "    Returns:\n",
    "        1-D array with the midpoint of every pair of adjacent distinct values,\n",
    "        in ascending order. It is empty when ``values`` holds fewer than two\n",
    "        distinct values, including when ``values`` itself is empty.\n",
    "    \"\"\"\n",
    "    # np.unique already returns the distinct values sorted ascending.\n",
    "    values = np.unique(values)\n",
    "    # Pair every value with its successor; slicing keeps this safe for\n",
    "    # arrays of length 0 or 1, where the result is simply empty.\n",
    "    return (values[:-1] + values[1:]) / 2\n",
    "\n",
    "\n",
    "get_cut_point(x[:, 6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.23645279766002802, 0.3815)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_gain(_x, col):\n",
    "    \"\"\"Find the threshold on column ``col`` with the largest information gain.\n",
    "\n",
    "    Every candidate from ``get_cut_point`` splits the rows into ``>= cut`` and\n",
    "    ``< cut``; the cut with the smallest size-weighted entropy wins, and the\n",
    "    first such cut is kept on ties.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(gain, cut)``: the entropy of ``_x`` minus the winning weighted\n",
    "        entropy, and the winning cut. When the column yields no candidate the\n",
    "        1e20 sentinel is never replaced, so the gain is hugely negative and\n",
    "        the cut is 0.\n",
    "    \"\"\"\n",
    "    lowest_entropy = 1e20\n",
    "    chosen_cut = 0\n",
    "    column = _x[:, col]\n",
    "    total = len(_x)\n",
    "\n",
    "    # Try every candidate threshold\n",
    "    for threshold in get_cut_point(column):\n",
    "        # Upper part first, then lower part, so the sum is accumulated in a\n",
    "        # fixed order.\n",
    "        parts = (_x[column >= threshold], _x[column < threshold])\n",
    "\n",
    "        # Weighted entropy: share of rows in each part times that part's entropy\n",
    "        weighted = 0\n",
    "        for part in parts:\n",
    "            weighted += len(part) / total * get_entropy(part)\n",
    "\n",
    "        if weighted < lowest_entropy:\n",
    "            lowest_entropy, chosen_cut = weighted, threshold\n",
    "\n",
    "    # Information gain: how much the entropy drops after the split (larger is better)\n",
    "    return get_entropy(_x) - lowest_entropy, chosen_cut\n",
    "\n",
    "\n",
    "get_gain(x, 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 1.5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_split_col(_x):\n",
    "    \"\"\"Pick the feature column and threshold with the largest information gain.\n",
    "\n",
    "    The last column of ``_x`` is the label and is never considered.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(col, cut)``. Both are -1 when no column achieves a gain\n",
    "        strictly greater than 0.\n",
    "    \"\"\"\n",
    "    # (gain, column, cut) of the best candidate seen so far\n",
    "    winner = (0, -1, -1)\n",
    "\n",
    "    feature_count = _x.shape[1] - 1\n",
    "    for col in range(feature_count):\n",
    "        gain, cut = get_gain(_x, col)\n",
    "\n",
    "        # Strict comparison: the earliest column wins a tie\n",
    "        if gain > winner[0]:\n",
    "            winner = (gain, col, cut)\n",
    "\n",
    "    return winner[1], winner[2]\n",
    "\n",
    "\n",
    "get_split_col(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0 value=1.000000\n",
      "Leaf y=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node and leaf objects used to build the tree\n",
    "class Node():\n",
    "    \"\"\"Internal tree node that stores a column index and a threshold value.\n",
    "\n",
    "    ``gt`` and ``lt`` hold the two children and start out as None.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, col, value):\n",
    "        self.col, self.value = col, value\n",
    "        self.gt = self.lt = None\n",
    "\n",
    "    def __str__(self):\n",
    "        fields = (self.col, self.value)\n",
    "        return 'Node col=%d value=%f' % fields\n",
    "\n",
    "    def set_child(self, symbol, node):\n",
    "        \"\"\"Store ``node`` as the 'gt' or 'lt' child; other symbols are ignored.\"\"\"\n",
    "        if symbol in ('gt', 'lt'):\n",
    "            setattr(self, symbol, node)\n",
    "\n",
    "\n",
    "class Leaf():\n",
    "    \"\"\"Terminal tree node that stores the predicted label ``y``.\"\"\"\n",
    "\n",
    "    def __init__(self, y):\n",
    "        self.y = y\n",
    "\n",
    "    def __str__(self):\n",
    "        # %d renders the label as an integer even when it is stored as a float\n",
    "        template = 'Leaf y=%d'\n",
    "        return template % self.y\n",
    "\n",
    "\n",
    "print(Node(0, 1.0)), print(Leaf(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 value=1.000000 \n"
     ]
    }
   ],
   "source": [
    "# Helper that prints the tree\n",
    "def print_tree(node, prefix='', subfix=''):\n",
    "    \"\"\"Print ``node`` and its descendants, one per line, indented by depth.\n",
    "\n",
    "    Each level adds four dashes to ``prefix``; ``subfix`` names the branch\n",
    "    through which the node was reached.\n",
    "    \"\"\"\n",
    "    prefix = prefix + '-' * 4\n",
    "    print(prefix, node, subfix)\n",
    "\n",
    "    # Only branch nodes have children to descend into\n",
    "    if not isinstance(node, Leaf):\n",
    "        branches = ((node.gt, 'value=gt'), (node.lt, 'value=lt'))\n",
    "        for child, label in branches:\n",
    "            if child:\n",
    "                print_tree(child, prefix, label)\n",
    "\n",
    "\n",
    "print_tree(Node(0, 1.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 1.5)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Find the column with the largest information gain on the full training set; the result is column 0 with cut point 1.5\n",
    "col,value = get_split_col(x)\n",
    "col,value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0 value=1.500000\n"
     ]
    }
   ],
   "source": [
    "# Create the root node from the result above; the root splits the data on column 0 at 1.5\n",
    "root = Node(0, 1.5)\n",
    "print(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 value=1.500000 \n",
      "-------- Leaf y=1 value=gt\n",
      "-------- Node col=6 value=0.381500 value=lt\n"
     ]
    }
   ],
   "source": [
    "# Attach child nodes to a parent node\n",
    "def create_children(_x, parent_node):\n",
    "    \"\"\"Create the 'gt' and 'lt' children of ``parent_node`` from the rows ``_x``.\n",
    "\n",
    "    Args:\n",
    "        _x: 2-D array of the rows that reach ``parent_node``; the last column\n",
    "            holds the labels.\n",
    "        parent_node: Node whose ``col`` and ``value`` define the split. Rows\n",
    "            with ``_x[:, col] >= value`` go to 'gt', the others to 'lt'.\n",
    "\n",
    "    A branch becomes a Leaf when all of its labels agree, or when no feature\n",
    "    offers a positive information gain (the most frequent label is used in\n",
    "    that case). Otherwise it becomes a Node on the best split. A branch that\n",
    "    receives no rows is left unset.\n",
    "    \"\"\"\n",
    "    column = _x[:, parent_node.col]\n",
    "    branches = (\n",
    "        ('gt', _x[column >= parent_node.value]),\n",
    "        ('lt', _x[column < parent_node.value]),\n",
    "    )\n",
    "\n",
    "    for symbol, sub_x in branches:\n",
    "        # No rows reach this branch, so there is nothing to build from.\n",
    "        if len(sub_x) == 0:\n",
    "            continue\n",
    "\n",
    "        # Distinct labels in this branch and how often each occurs\n",
    "        labels, counts = np.unique(sub_x[:, -1], return_counts=True)\n",
    "\n",
    "        # A single distinct label means the branch is pure: make a leaf\n",
    "        if len(labels) == 1:\n",
    "            parent_node.set_child(symbol, Leaf(labels[0]))\n",
    "            continue\n",
    "\n",
    "        # Otherwise look for the best column and threshold to split on\n",
    "        split_col, value = get_split_col(sub_x)\n",
    "\n",
    "        # get_split_col reports -1 when no split has positive gain. Using -1\n",
    "        # as a column index would split on the label column itself, so fall\n",
    "        # back to a majority-vote leaf instead.\n",
    "        if split_col == -1:\n",
    "            parent_node.set_child(symbol, Leaf(labels[counts.argmax()]))\n",
    "            continue\n",
    "\n",
    "        # Attach the branch node to the parent\n",
    "        parent_node.set_child(symbol, Node(col=split_col, value=value))\n",
    "\n",
    "\n",
    "create_children(x, root)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 value=1.500000 \n",
      "-------- Leaf y=1 value=gt\n",
      "-------- Node col=6 value=0.381500 value=lt\n",
      "------------ Node col=7 value=0.126000 value=gt\n",
      "------------ Leaf y=1 value=lt\n"
     ]
    }
   ],
   "source": [
    "#继续创建,0<1.5节点的下一层\n",
    "x_0_lt_15 = x[x[:, 0] < 1.5]\n",
    "create_children(x_0_lt_15, root.lt)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 value=1.500000 \n",
      "-------- Leaf y=1 value=gt\n",
      "-------- Node col=6 value=0.381500 value=lt\n",
      "------------ Node col=7 value=0.126000 value=gt\n",
      "---------------- Leaf y=0 value=gt\n",
      "---------------- Leaf y=1 value=lt\n",
      "------------ Leaf y=1 value=lt\n"
     ]
    }
   ],
   "source": [
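    "# Continue one level down: expand the node reached via column 0 < 1.5 and column 6 >= 0.3815 (it splits on column 7 at 0.126)\n",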
    "x_0_lt_15_and_6_gt_03815 = x_0_lt_15[x_0_lt_15[:, 6] >= 0.3815]\n",
    "create_children(x_0_lt_15_and_6_gt_03815, root.lt.gt)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "-------------------------\n",
      "0.7142857142857143\n"
     ]
    }
   ],
   "source": [
    "# Prediction routine, followed by an accuracy check\n",
    "def pred(_x, node):\n",
    "    \"\"\"Return the label the tree rooted at ``node`` assigns to the row ``_x``.\n",
    "\n",
    "    At every branch node the row moves to ``lt`` when ``_x[col] < value`` and\n",
    "    to ``gt`` otherwise, until a Leaf is reached; its ``y`` is the prediction.\n",
    "    \"\"\"\n",
    "    current = node\n",
    "    while not isinstance(current, Leaf):\n",
    "        goes_lower = _x[current.col] < current.value\n",
    "        current = current.lt if goes_lower else current.gt\n",
    "\n",
    "    return current.y\n",
    "\n",
    "\n",
    "def get_accuracy(data, node):\n",
    "    \"\"\"Return the fraction of rows in ``data`` that the tree labels correctly.\n",
    "\n",
    "    The last column of each row is the true label; ``node`` is the root that\n",
    "    is handed to ``pred``. ``data`` must contain at least one row.\n",
    "    \"\"\"\n",
    "    hits = sum(1 for row in data if pred(row, node) == row[-1])\n",
    "    return hits / len(data)\n",
    "\n",
    "\n",
    "# Accuracy on the training rows\n",
    "print(get_accuracy(x, root))\n",
    "\n",
    "print('-------------------------')\n",
    "\n",
    "# Accuracy on the held-out rows\n",
    "print(get_accuracy(test_x, root))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
